{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Feature attribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "__author__ = \"Christopher Potts\"\n",
    "__version__ = \"CS224u, Stanford, Spring 2022\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Contents\n",
    "\n",
    "1. [Overview](#Overview)\n",
    "1. [InputXGradients](#InputXGradients)\n",
    "1. [Selectivity examples](#Selectivity-examples)\n",
    "1. [Simple feed-forward classifier example](#Simple-feed-forward-classifier-example)\n",
    "1. [Bag-of-words classifier for the SST](#Bag-of-words-classifier-for-the-SST)\n",
    "1. [BERT example](#BERT-example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "This notebook is an experimental extension of the CS224u course code. It focuses on the [Integrated Gradients](https://arxiv.org/abs/1703.01365) method for feature attribution, with comparisons to the \"inputs $\\times$ gradients\" method. To run the notebook, first install [the Captum library](https://captum.ai/):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: captum in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (0.5.0)\n",
      "Requirement already satisfied: matplotlib in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from captum) (3.4.3)\n",
      "Requirement already satisfied: torch>=1.6 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from captum) (1.10.0)\n",
      "Requirement already satisfied: numpy in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from captum) (1.20.3)\n",
      "Requirement already satisfied: typing-extensions in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from torch>=1.6->captum) (3.10.0.2)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (8.4.0)\n",
      "Requirement already satisfied: cycler>=0.10 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (0.10.0)\n",
      "Requirement already satisfied: python-dateutil>=2.7 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (2.8.2)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (3.0.4)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from matplotlib->captum) (1.3.1)\n",
      "Requirement already satisfied: six in /Applications/anaconda3/envs/nlu/lib/python3.9/site-packages (from cycler>=0.10->matplotlib->captum) (1.16.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install captum"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is not currently a required installation (but it will be in future years)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## InputXGradients\n",
    "\n",
    "For both implementations, the `forward` method of `model` is used. `X` is an (m x n) tensor of attributions. Use `targets=None` for models with scalar outputs, else supply a LongTensor giving a label for each example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def grad_x_input(model, X, targets=None):\n",
    "    \"\"\"Implementation using PyTorch directly.\"\"\"\n",
    "    X.requires_grad = True\n",
    "    y = model(X)\n",
    "    y = y if targets is None else y[list(range(len(y))), targets]\n",
    "    (grads, ) = torch.autograd.grad(y.unbind(), X)\n",
    "    return grads * X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from captum.attr import InputXGradient\n",
    "\n",
    "def captum_grad_x_input(model, X, target):\n",
    "    \"\"\"Captum-based implementation.\"\"\"\n",
    "    X.requires_grad = True\n",
    "    amod = InputXGradient(model)\n",
    "    return amod.attribute(X, target=target)"
   ]
  },
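  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of what both functions compute: writing $F$ for the model output being explained (the scalar output, or the output unit selected by the target label) and $x_i$ for feature $i$ of an input $x$, which is notation introduced here rather than in the code, the attribution is\n",
    "\n",
    "$$\\text{attr}_i(x) = x_i \\cdot \\frac{\\partial F(x)}{\\partial x_i}$$\n",
    "\n",
    "so a feature receives zero attribution whenever its value is zero or the model is locally flat in that feature."
   ]
  },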
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Selectivity examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from captum.attr import IntegratedGradients\n",
    "from captum.attr import InputXGradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SelectivityAssessor(nn.Module):\n",
    "    \"\"\"Model used by Sundararajan et al, section 2.1 to show that\n",
    "    input * gradients violates their selectivity axiom.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, X):\n",
    "        return 1.0 - self.relu(1.0 - X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_mod = SelectivityAssessor()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simple inputs with just one feature:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_sel = torch.FloatTensor([[0.0], [2.0]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The outputs for our two examples differ:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.],\n",
       "        [1.]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sel_mod(X_sel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, `InputXGradient` assigns the same importance to the feature across the two examples, violating selectivity:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.],\n",
       "        [-0.]], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "captum_grad_x_input(sel_mod, X_sel, target=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Integrated gradients addresses the problem by averaging gradients across all interpolated representations between the baseline and the actual input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "ig_sel = IntegratedGradients(sel_mod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_baseline = torch.FloatTensor([[0.0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.],\n",
       "        [1.]], dtype=torch.float64, grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ig_sel.attribute(X_sel, sel_baseline)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy implementation to help bring out what is happening:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ig_reference_implementation(model, x, base, m=50):\n",
    "    vals = []\n",
    "    for k in range(m):\n",
    "        # Interpolated representation:\n",
    "        xx = (base + (k/m)) * (x - base)\n",
    "        # Gradient for the interpolated example:\n",
    "        xx.requires_grad = True\n",
    "        y = model(xx)\n",
    "        (grads, ) = torch.autograd.grad(y.unbind(), xx)\n",
    "        vals.append(grads)\n",
    "    return (1 / m) * torch.cat(vals).sum(axis=0) * (x - base)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ig_reference_implementation(sel_mod, torch.FloatTensor([[2.0]]), sel_baseline)"
   ]
  },
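  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The computation above can be read as a Riemann-sum approximation of the integrated gradients integral from Sundararajan et al. As a sketch, with $F$ for the model, $x$ for the input, $x'$ for the baseline, and $m$ for the number of steps (notation introduced here; the toy implementation evaluates the sum at $k = 0, \\ldots, m-1$):\n",
    "\n",
    "$$\\text{IG}_i(x) = (x_i - x'_i) \\int_0^1 \\frac{\\partial F(x' + \\alpha (x - x'))}{\\partial x_i} \\, d\\alpha \\approx (x_i - x'_i) \\cdot \\frac{1}{m} \\sum_{k} \\frac{\\partial F(x' + \\frac{k}{m}(x - x'))}{\\partial x_i}$$\n",
    "\n",
    "For the example here, $F(x) = 1 - \\text{ReLU}(1 - x)$, $x' = 0$, and $x = 2$. The gradient along the path is 1 while $2\\alpha < 1$ and 0 afterwards, so the integral is $1/2$ and the attribution is $(2 - 0) \\cdot 1/2 = 1$, which equals $F(2) - F(0)$ and matches the outputs above."
   ]
  },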
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simple feed-forward classifier example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from captum.attr import IntegratedGradients\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.feature_selection import mutual_info_classif\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch\n",
    "from torch_shallow_neural_classifier import TorchShallowNeuralClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_cls, y_cls = make_classification(\n",
    "    n_samples=5000,\n",
    "    n_classes=3,\n",
    "    n_features=5,\n",
    "    n_informative=3,\n",
    "    n_redundant=0,\n",
    "    random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classification problem has two uninformative features:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.20138107, 0.02833358, 0.11584416, 0.        , 0.        ])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mutual_info_classif(X_cls, y_cls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_cls_train, X_cls_test, y_cls_train, y_cls_test = train_test_split(X_cls, y_cls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier = TorchShallowNeuralClassifier()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 449. Training loss did not improve more than tol=1e-05. Final error is 1.3419027030467987."
     ]
    }
   ],
   "source": [
    "_ = classifier.fit(X_cls_train, y_cls_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "cls_preds = classifier.predict(X_cls_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8568"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_cls_test, cls_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_ig = IntegratedGradients(classifier.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_baseline = torch.zeros(1, X_cls_train.shape[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Integrated gradients with respect to the actual labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier_attrs = classifier_ig.attribute(\n",
    "    torch.FloatTensor(X_cls_test),\n",
    "    classifier_baseline,\n",
    "    target=torch.LongTensor(y_cls_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Average attribution is low for the two uninformative features:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.6544,  0.6739,  0.7057, -0.0173, -0.0059], dtype=torch.float64)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classifier_attrs.mean(axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bag-of-words classifier for the SST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from captum.attr import IntegratedGradients\n",
    "from nltk.corpus import stopwords\n",
    "from operator import itemgetter\n",
    "import os\n",
    "from sklearn.metrics import classification_report\n",
    "import torch\n",
    "from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n",
    "import sst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "SST_HOME = os.path.join(\"data\", \"sentiment\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Bag-of-word featurization with stopword removal to make this a little easier to study:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "stopwords = set(stopwords.words('english'))\n",
    "\n",
    "def phi(text):\n",
    "    return Counter([w for w in text.lower().split() if w not in stopwords])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_mlp(X, y):\n",
    "    mod = TorchShallowNeuralClassifier(early_stopping=True)\n",
    "    mod.fit(X, y)\n",
    "    return mod"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Stopping after epoch 24. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 1.3742991983890533"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "    negative      0.629     0.696     0.661       428\n",
      "     neutral      0.295     0.100     0.150       229\n",
      "    positive      0.625     0.773     0.691       444\n",
      "\n",
      "    accuracy                          0.603      1101\n",
      "   macro avg      0.516     0.523     0.500      1101\n",
      "weighted avg      0.558     0.603     0.567      1101\n",
      "\n"
     ]
    }
   ],
   "source": [
    "experiment = sst.experiment(\n",
    "    sst.train_reader(SST_HOME),\n",
    "    phi,\n",
    "    fit_mlp,\n",
    "    sst.dev_reader(SST_HOME))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Trained model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "sst_classifier = experiment['model']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Captum needs to have labels as indices rather than strings:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['negative', 'neutral', 'positive']"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sst_classifier.classes_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_sst_test = [sst_classifier.classes_.index(label)\n",
    "              for label in experiment['assess_datasets'][0]['y']]\n",
    "\n",
    "sst_preds = [sst_classifier.classes_.index(label)\n",
    "             for label in experiment['predictions'][0]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our featurized test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_sst_test = experiment['assess_datasets'][0]['X']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Feature names to help with analyses:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "fnames = experiment['train_dataset']['vectorizer'].get_feature_names()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Integrated gradients:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "sst_ig = IntegratedGradients(sst_classifier.model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All-0s baseline:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "sst_baseline = torch.zeros(1, experiment['train_dataset']['X'].shape[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Attributions with respect to the model's predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "sst_attrs = sst_ig.attribute(\n",
    "    torch.FloatTensor(X_sst_test),\n",
    "    sst_baseline,\n",
    "    target=torch.LongTensor(sst_preds))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Helper functions for error analysis:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def error_analysis(gold=1, predicted=2):\n",
    "    err_ind = [i for i, (g, p) in enumerate(zip(y_sst_test, sst_preds))\n",
    "               if g == gold and p == predicted]\n",
    "    attr_lookup = create_attr_lookup(sst_attrs[err_ind])\n",
    "    return attr_lookup, err_ind\n",
    "\n",
    "def create_attr_lookup(attrs):\n",
    "    mu = attrs.mean(axis=0).detach().numpy()\n",
    "    return sorted(zip(fnames, mu), key=itemgetter(1), reverse=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "sst_attrs_lookup, sst_err_ind = error_analysis(gold=1, predicted=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(',', 0.04512846003808179),\n",
       " ('.', 0.03875377384651548),\n",
       " ('film', 0.036562292947638124),\n",
       " ('fun', 0.02995531556022619),\n",
       " ('best', 0.015621606617978723)]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sst_attrs_lookup[: 5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Error analysis for a specific example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "ex_ind = sst_err_ind[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'No one goes unindicted here , which is probably for the best .'"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "experiment['assess_datasets'][0]['raw_examples'][ex_ind]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "ex_attr_lookup = create_attr_lookup(sst_attrs[ex_ind:ex_ind+1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('best', 0.43364394640626314),\n",
       " (',', 0.04500691178712216),\n",
       " ('.', 0.03940604247146967),\n",
       " ('probably', 0.03321118433841792),\n",
       " ('one', 0.008722432294266332),\n",
       " ('goes', -0.03914730368530946)]"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[(f, a) for f, a in ex_attr_lookup if a != 0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BERT example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "from captum.attr import LayerIntegratedGradients\n",
    "from captum.attr import visualization as viz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_weights_name = 'cardiffnlp/twitter-roberta-base-sentiment'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_tokenizer = AutoTokenizer.from_pretrained(hf_weights_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "hf_model = AutoModelForSequenceClassification.from_pretrained(hf_weights_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hf_predict_one_proba(text):\n",
    "    input_ids = hf_tokenizer.encode(\n",
    "        text, add_special_tokens=True, return_tensors='pt')\n",
    "    hf_model.eval()\n",
    "    with torch.no_grad():\n",
    "        logits = hf_model(input_ids)[0]\n",
    "        preds = F.softmax(logits, dim=1)\n",
    "    hf_model.train()\n",
    "    return preds.squeeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hf_ig_encodings(text):\n",
    "    pad_id = hf_tokenizer.pad_token_id\n",
    "    cls_id = hf_tokenizer.cls_token_id\n",
    "    sep_id = hf_tokenizer.sep_token_id\n",
    "    input_ids = hf_tokenizer.encode(text, add_special_tokens=False)\n",
    "    base_ids = [pad_id] * len(input_ids)\n",
    "    input_ids = [cls_id] + input_ids + [sep_id]\n",
    "    base_ids = [cls_id] + base_ids + [sep_id]\n",
    "    return torch.LongTensor([input_ids]), torch.LongTensor([base_ids])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hf_ig_analyses(text2class):\n",
    "    data = []\n",
    "    for text, true_class in text2class.items():\n",
    "        score_vis = hf_ig_analysis_one(text, true_class)\n",
    "        data.append(score_vis)\n",
    "    viz.visualize_text(data)\n",
    "\n",
    "\n",
    "def hf_ig_analysis_one(text, true_class):\n",
    "    # Option to look at different layers:\n",
    "    # layer = model.roberta.encoder.layer[0]\n",
    "    # layer = model.roberta.embeddings.word_embeddings\n",
    "    layer = hf_model.roberta.embeddings\n",
    "\n",
    "    def ig_forward(inputs):\n",
    "        return hf_model(inputs).logits\n",
    "\n",
    "    ig = LayerIntegratedGradients(ig_forward, layer)\n",
    "\n",
    "    input_ids, base_ids = hf_ig_encodings(text)\n",
    "\n",
    "    attrs, delta = ig.attribute(\n",
    "        input_ids,\n",
    "        base_ids,\n",
    "        target=true_class,\n",
    "        return_convergence_delta=True)\n",
    "\n",
    "    # Summarize and z-score normalize the attributions\n",
    "    # for each representation in `layer`:\n",
    "    scores = attrs.sum(dim=-1).squeeze(0)\n",
    "    scores = (scores - scores.mean()) / scores.norm()\n",
    "\n",
    "    # Intuitive tokens to help with analysis:\n",
    "    raw_input = hf_tokenizer.convert_ids_to_tokens(input_ids.tolist()[0])\n",
    "    # RoBERTa-specific clean-up:\n",
    "    raw_input = [x.strip(\"Ġ\") for x in raw_input]\n",
    "\n",
    "    # Predictions for comparisons:\n",
    "    pred_probs = hf_predict_one_proba(text)\n",
    "    pred_class = pred_probs.argmax()\n",
    "\n",
    "    score_vis = viz.VisualizationDataRecord(\n",
    "        word_attributions=scores,\n",
    "        pred_prob=pred_probs.max(),\n",
    "        pred_class=pred_class,\n",
    "        true_class=true_class,\n",
    "        attr_class=None,\n",
    "        attr_score=attrs.sum(),\n",
    "        raw_input_ids=raw_input,\n",
    "        convergence_score=delta)\n",
    "\n",
    "    return score_vis"
   ]
  },
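  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the scoring step in `hf_ig_analysis_one`: writing $a_{j,d}$ for the attribution assigned to dimension $d$ of the layer output at token position $j$ (notation introduced here), the code computes one score per token and then rescales:\n",
    "\n",
    "$$s_j = \\sum_d a_{j,d} \\qquad \\tilde{s}_j = \\frac{s_j - \\bar{s}}{\\lVert s \\rVert_2}$$\n",
    "\n",
    "Here $\\bar{s}$ is the mean of the $s_j$ and $\\lVert s \\rVert_2$ is the L2 norm of the uncentered scores, so the rescaled values are centered at zero but are scaled by the norm rather than by the standard deviation."
   ]
  },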
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table width: 100%><div style=\"border-top: 1px solid; margin-top: 5px;             padding-top: 5px; display: inline-block\"><b>Legend: </b><span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 60%)\"></span> Negative  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(0, 75%, 100%)\"></span> Neutral  <span style=\"display: inline-block; width: 10px; height: 10px;                 border: 1px solid; background-color:                 hsl(120, 75%, 50%)\"></span> Positive  </div><tr><th>True Label</th><th>Predicted Label</th><th>Attribution Label</th><th>Attribution Score</th><th>Word Importance</th><tr><td><text style=\"padding-right:2em\"><b>2</b></text></td><td><text style=\"padding-right:2em\"><b>2 (0.82)</b></text></td><td><text style=\"padding-right:2em\"><b>None</b></text></td><td><text style=\"padding-right:2em\"><b>1.98</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #s                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> They                    </font></mark><mark style=\"background-color: hsl(0, 75%, 89%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> said                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> it                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> would                    </font></mark><mark style=\"background-color: hsl(120, 75%, 91%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> be            
        </font></mark><mark style=\"background-color: hsl(120, 75%, 59%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> great                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> and                    </font></mark><mark style=\"background-color: hsl(0, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> they                    </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> right                    </font></mark><mark style=\"background-color: hsl(0, 75%, 94%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    
</font></mark><mark style=\"background-color: hsl(120, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #/s                    </font></mark></td><tr><tr><td><text style=\"padding-right:2em\"><b>0</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.50)</b></text></td><td><text style=\"padding-right:2em\"><b>None</b></text></td><td><text style=\"padding-right:2em\"><b>0.07</b></text></td><td><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #s                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> They                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> said                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> it                    </font></mark><mark style=\"background-color: hsl(0, 75%, 94%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> would                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> be                    </font></mark><mark style=\"background-color: hsl(0, 75%, 72%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> great                    </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> and                    </font></mark><mark style=\"background-color: hsl(120, 75%, 94%); opacity:1.0;                     
line-height:1.75\"><font color=\"black\"> they                    </font></mark><mark style=\"background-color: hsl(120, 75%, 87%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(120, 75%, 70%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> wrong                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(120, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #/s                    </font></mark></td><tr><tr><td><text style=\"padding-right:2em\"><b>2</b></text></td><td><text style=\"padding-right:2em\"><b>2 (0.76)</b></text></td><td><text style=\"padding-right:2em\"><b>None</b></text></td><td><text style=\"padding-right:2em\"><b>2.39</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #s                    </font></mark><mark style=\"background-color: hsl(0, 75%, 90%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> They                    </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(120, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> right                    </font></mark><mark style=\"background-color: hsl(0, 75%, 89%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> to                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> say                    </font></mark><mark 
style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> it                    </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> would                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> be                    </font></mark><mark style=\"background-color: hsl(120, 75%, 59%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> great                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #/s                    </font></mark></td><tr><tr><td><text style=\"padding-right:2em\"><b>0</b></text></td><td><text style=\"padding-right:2em\"><b>0 (0.62)</b></text></td><td><text style=\"padding-right:2em\"><b>None</b></text></td><td><text style=\"padding-right:2em\"><b>3.46</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #s                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> They                    </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(120, 75%, 64%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> wrong                    </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font 
color=\"black\"> to                    </font></mark><mark style=\"background-color: hsl(120, 75%, 92%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> say                    </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> it                    </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> would                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> be                    </font></mark><mark style=\"background-color: hsl(0, 75%, 81%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> great                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #/s                    </font></mark></td><tr><tr><td><text style=\"padding-right:2em\"><b>2</b></text></td><td><text style=\"padding-right:2em\"><b>2 (0.77)</b></text></td><td><text style=\"padding-right:2em\"><b>None</b></text></td><td><text style=\"padding-right:2em\"><b>1.78</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #s                    </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> They                    </font></mark><mark style=\"background-color: hsl(0, 75%, 91%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> said                    </font></mark><mark style=\"background-color: hsl(0, 75%, 97%); 
opacity:1.0;                     line-height:1.75\"><font color=\"black\"> it                    </font></mark><mark style=\"background-color: hsl(0, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> would                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> be                    </font></mark><mark style=\"background-color: hsl(120, 75%, 64%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> stellar                    </font></mark><mark style=\"background-color: hsl(0, 75%, 93%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(0, 75%, 94%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> and                    </font></mark><mark style=\"background-color: hsl(0, 75%, 91%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> they                    </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(120, 75%, 83%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> correct                    </font></mark><mark style=\"background-color: hsl(120, 75%, 95%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    
</font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #/s                    </font></mark></td><tr><tr><td><text style=\"padding-right:2em\"><b>0</b></text></td><td><text style=\"padding-right:2em\"><b>1 (0.47)</b></text></td><td><text style=\"padding-right:2em\"><b>None</b></text></td><td><text style=\"padding-right:2em\"><b>1.17</b></text></td><td><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #s                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> They                    </font></mark><mark style=\"background-color: hsl(120, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> said                    </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> it                    </font></mark><mark style=\"background-color: hsl(0, 75%, 98%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> would                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> be                    </font></mark><mark style=\"background-color: hsl(0, 75%, 83%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> stellar                    </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> ,                    </font></mark><mark style=\"background-color: hsl(0, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> and                    </font></mark><mark style=\"background-color: hsl(120, 75%, 99%); opacity:1.0;                     
line-height:1.75\"><font color=\"black\"> they                    </font></mark><mark style=\"background-color: hsl(120, 75%, 96%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> were                    </font></mark><mark style=\"background-color: hsl(120, 75%, 58%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> incorrect                    </font></mark><mark style=\"background-color: hsl(0, 75%, 100%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> .                    </font></mark><mark style=\"background-color: hsl(0, 75%, 99%); opacity:1.0;                     line-height:1.75\"><font color=\"black\"> #/s                    </font></mark></td><tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Contrasting sentence pairs; each text is mapped to its true class label.\n",
    "score_vis = hf_ig_analyses({\n",
    "    \"They said it would be great, and they were right.\": 2,\n",
    "    \"They said it would be great, and they were wrong.\": 0,\n",
    "    \"They were right to say it would be great.\": 2,\n",
    "    \"They were wrong to say it would be great.\": 0,\n",
    "    \"They said it would be stellar, and they were correct.\": 2,\n",
    "    \"They said it would be stellar, and they were incorrect.\": 0})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
